from math import sqrt
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import k_hop_subgraph as subgraph

from ._base import _BaseExplainer

class GNNPool(nn.Module):
    '''Marker module for graph-pooling layers.

    Carries no parameters or forward logic of its own. Layer-grouping code
    (``_BaseDecomposition.extract_step``) checks ``isinstance(..., GNNPool)``
    to recognise where pooling happens in a model.
    '''

    def __init__(self):
        super(GNNPool, self).__init__()

# class _BaseDecomposition(_BaseExplainer):
#     '''
#     Code adapted from Dive into Graphs (DIG)
#     Code: https://github.com/divelab/DIG
#     '''

#     def __init__(self, model: nn.Module):
#         super().__init__(model=model) # Will set self.model = model
#         # Other properties: self.L (number of layers)

#     @property
#     def __num_hops__(self):
#         if self.explain_graph:
#             return -1
#         else:
#             return self.L
    
#     def set_graph_attr(self,
#                 x: Tensor,
#                 edge_index: Tensor,
#                 **kwargs
#                 ):
#         self.num_edges = edge_index.shape[1]
#         self.num_nodes = x.shape[0]
#         self.device = x.device

#     def extract_step(self, x: Tensor, edge_index: Tensor, detach: bool = True, split_fc: bool = False, forward_kwargs: dict = None):
#         '''Gets information about every layer in the graph
#         Args:

#             forward_kwargs (tuple, optional): Additional arguments to model forward call (other than x and edge_index)
#                 (default: :obj:`None`)
#         '''

#         layer_extractor = []
#         hooks = []

#         def register_hook(module: nn.Module):
#             if not list(module.children()) or isinstance(module, MessagePassing):
#                 hooks.append(module.register_forward_hook(forward_hook))

#         def forward_hook(module: nn.Module, input: Tuple[Tensor], output: Tensor):
#             # input contains x and edge_index
#             if detach:
#                 layer_extractor.append((module, input[0].clone().detach(), output.clone().detach()))
#             else:
#                 layer_extractor.append((module, input[0], output))

#         # --- register hooks ---
#         self.model.apply(register_hook)

#         # ADDED: OWEN QUEEN --------------
#         if forward_kwargs is None:
#             _ = self.model(x, edge_index)
#         else:
#             _ = self.model(x, edge_index, **forward_kwargs)
#         # --------------------------------
#         # Remove hooks:
#         for hook in hooks:
#             hook.remove()

#         # --- divide layer sets ---

#         # print('Layer extractor', [layer_extractor[i][0] for i in range(len(layer_extractor))])

#         walk_steps = []
#         fc_steps = []
#         pool_flag = False
#         step = {'input': None, 'module': [], 'output': None}
#         for layer in layer_extractor:
#             if isinstance(layer[0], MessagePassing):
#                 if step['module']: # Append step that had previously been building
#                     walk_steps.append(step)

#                 step = {'input': layer[1], 'module': [], 'output': None}

#             elif isinstance(layer[0], GNNPool):
#                 pool_flag = True
#                 if step['module']:
#                     walk_steps.append(step)

#                 # Putting in GNNPool
#                 step = {'input': layer[1], 'module': [], 'output': None}

#             elif isinstance(layer[0], nn.Linear):
#                 if step['module']:
#                     if isinstance(step['module'][0], MessagePassing):
#                         walk_steps.append(step) # Append MessagePassing layer to walk_steps
#                     else: # Always append Linear layers to fc_steps
#                         fc_steps.append(step)

#                 step = {'input': layer[1], 'module': [], 'output': None}

#             # Also appends non-trainable layers to step (not modifying input):
#             step['module'].append(layer[0])
#             step['output'] = layer[2]

#         if step['module']:
#             if isinstance(step['module'][0], MessagePassing):
#                 walk_steps.append(step)
#             else: # Append anything to FC that is not MessagePassing at its origin
#                 # Still supports sequential layers
#                 fc_steps.append(step)
#             # print('layer', layer[0])
#             # if isinstance(layer[0], MessagePassing) or isinstance(layer[0], GNNPool):
#             #     if isinstance(layer[0], GNNPool):
#             #         pool_flag = True
#             #     if step['module'] and step['input'] is not None:
#             #         walk_steps.append(step)
#             #     step = {'input': layer[1], 'module': [], 'output': None}
#             # if pool_flag and split_fc and isinstance(layer[0], nn.Linear):
#             #     if step['module']:
#             #         fc_steps.append(step)
#             #     step = {'input': layer[1], 'module': [], 'output': None}
#             # step['module'].append(layer[0])
#             # step['output'] = layer[2]

#         for walk_step in walk_steps:
#             if hasattr(walk_step['module'][0], 'nn') and walk_step['module'][0].nn is not None:
#                 # We don't allow any outside nn during message flow process in GINs
#                 walk_step['module'] = [walk_step['module'][0]]
#             elif hasattr(walk_step['module'][0], 'lin') and walk_step['module'][0].lin is not None:
#                 walk_step['module'] = [walk_step['module'][0]]

#         # print('Walk steps', [walk_steps[i]['module'] for i in range(len(walk_steps))])
#         # print('fc steps', [fc_steps[i]['module'] for i in range(len(fc_steps))])

#         return walk_steps, fc_steps

#     def walks_pick(self,
#                    edge_index: Tensor,
#                    pick_edge_indices: List,
#                    walk_indices: List=[],
#                    num_layers=0
#                    ):
#         walk_indices_list = []
#         for edge_idx in pick_edge_indices:

#             # Adding one edge
#             walk_indices.append(edge_idx)
#             _, new_src = src, tgt = edge_index[:, edge_idx]
#             next_edge_indices = np.array((edge_index[0, :] == new_src).nonzero().view(-1))

#             # Finding next edge
#             if len(walk_indices) >= num_layers:
#                 # return one walk
#                 walk_indices_list.append(walk_indices.copy())
#             else:
#                 walk_indices_list += self.walks_pick(edge_index, next_edge_indices, walk_indices, num_layers)

#             # remove the last edge
#             walk_indices.pop(-1)

#         return walk_indices_list

class _BaseDecomposition(_BaseExplainer):
    '''
    Base class for decomposition-based explainers.

    Code adapted from Dive into Graphs (DIG)
    Code: https://github.com/divelab/DIG
    '''

    def __init__(self, model: nn.Module):
        super().__init__(model=model) # Will set self.model = model
        # Other properties: self.L (number of layers)

    @property
    def __num_hops__(self):
        '''Return -1 when ``self.explain_graph`` is truthy, else ``self.L``.

        Neither attribute is set in this class; they are expected to be
        provided by the base class or a subclass.
        '''
        if self.explain_graph:
            return -1
        else:
            return self.L

    def set_graph_attr(self,
                x: Tensor,
                edge_index: Tensor,
                **kwargs
                ):
        '''Cache edge count, node count and device of the current graph.

        Args:
            x (Tensor): Input whose first dimension is taken as the node count.
            edge_index (Tensor): Edge index whose second dimension is taken as
                the edge count.
            **kwargs: Accepted and ignored.
        '''
        self.num_edges = edge_index.shape[1]
        self.num_nodes = x.shape[0]
        self.device = x.device

    def extract_step(self,
            x: Tensor,
            edge_index: Tensor,
            detach: bool = True,
            split_fc: bool = False,
            forward_kwargs: Optional[dict] = None):
        '''Run one forward pass with hooks and group the recorded layers.

        A forward hook is attached to every leaf module and every
        ``MessagePassing`` module of ``self.model``. Each hook records
        ``(module, first positional input, output)``. The records are then
        grouped into step dicts with keys ``'input'``, ``'module'`` (list of
        modules) and ``'output'``.

        Args:
            x (Tensor): First argument of the model forward call.
            edge_index (Tensor): Second argument of the model forward call.
            detach (bool): If True, store cloned and detached copies of the
                recorded inputs/outputs. (default: :obj:`True`)
            split_fc (bool): If True, once a ``GNNPool`` module has been seen,
                every ``nn.Linear`` starts a new step collected into the second
                return value. (default: :obj:`False`)
            forward_kwargs (dict, optional): Additional keyword arguments to the
                model forward call (other than x and edge_index).
                (default: :obj:`None`)

        Returns:
            ``(walk_steps, fc_steps)`` (two lists of step dicts) if ``split_fc``
            is True. Otherwise ``(walk_steps, fc_step)``, where ``fc_step`` is
            the last step dict built, which was not added to ``walk_steps``.
        '''
        if forward_kwargs is None:
            forward_kwargs = {}

        layer_extractor = []
        hooks = []

        def register_hook(module: nn.Module):
            if not list(module.children()) or isinstance(module, MessagePassing):
                hooks.append(module.register_forward_hook(forward_hook))

        def forward_hook(module: nn.Module, inputs: Tuple[Tensor], output: Tensor):
            # inputs contains x and edge_index; only the first one is recorded
            if detach:
                layer_extractor.append((module, inputs[0].clone().detach(), output.clone().detach()))
            else:
                layer_extractor.append((module, inputs[0], output))

        # --- register hooks ---
        self.model.apply(register_hook)

        # Always detach the hooks, even if the forward pass raises; otherwise
        # they would stay attached to self.model and fire on later calls.
        try:
            self.model(x, edge_index, **forward_kwargs)
        finally:
            for hook in hooks:
                hook.remove()

        # --- divide layer sets ---
        walk_steps = []
        fc_steps = []
        pool_flag = False
        step = {'input': None, 'module': [], 'output': None}
        for layer in layer_extractor:
            if isinstance(layer[0], MessagePassing) or isinstance(layer[0], GNNPool):
                if isinstance(layer[0], GNNPool):
                    pool_flag = True
                # The initial placeholder step has input None and is never kept.
                if step['module'] and step['input'] is not None:
                    walk_steps.append(step)
                step = {'input': layer[1], 'module': [], 'output': None}
            if pool_flag and split_fc and isinstance(layer[0], nn.Linear):
                if step['module']:
                    fc_steps.append(step)
                step = {'input': layer[1], 'module': [], 'output': None}
            # Every recorded module (including non-trainable ones) joins the
            # current step; the step's output tracks the latest module's output.
            step['module'].append(layer[0])
            step['output'] = layer[2]

        for walk_step in walk_steps:
            head = walk_step['module'][0]
            for_nn = hasattr(head, 'nn') and head.nn is not None
            for_lin = hasattr(head, 'lin') and head.lin is not None
            if for_nn or for_lin:
                # We don't allow any outside nn during message flow process in GINs
                walk_step['module'] = [head]

        if split_fc:
            if step['module']:
                fc_steps.append(step)
            return walk_steps, fc_steps

        return walk_steps, step

    def walks_pick(self,
                   edge_index: Tensor,
                   pick_edge_indices: List,
                   walk_indices: Optional[List] = None,
                   num_layers=0
                   ):
        '''Recursively enumerate edge walks starting from the picked edges.

        A walk is a list of edge indices (columns of ``edge_index``). Each edge
        after the first has ``edge_index[0]`` equal to the previous edge's
        ``edge_index[1]``.

        Args:
            edge_index (Tensor): Edge index; row 0 holds sources and row 1
                holds targets.
            pick_edge_indices (List): Candidate edge indices for the next
                position of the walk.
            walk_indices (List, optional): Prefix of the walk built so far. It
                is extended and restored in place during recursion.
                (default: :obj:`None`, meaning an empty prefix)
            num_layers (int): A walk is emitted as soon as its length reaches
                ``num_layers``. (default: 0)

        Returns:
            List of walks, each a fresh list of edge indices.
        '''
        if walk_indices is None:
            # Fresh list per top-level call; never share a default list.
            walk_indices = []

        # Row 0 moved to CPU once per call instead of once per picked edge.
        sources = edge_index[0, :].cpu()

        walk_indices_list = []
        for edge_idx in pick_edge_indices:

            # Adding one edge
            walk_indices.append(edge_idx)

            if len(walk_indices) >= num_layers:
                # Walk is complete: store a snapshot of it.
                walk_indices_list.append(walk_indices.copy())
            else:
                # Continue from this edge's target: the next edges are those
                # whose source equals it.
                tgt = edge_index[1, edge_idx].item()
                next_edge_indices = np.array((sources == tgt).nonzero().view(-1))
                walk_indices_list += self.walks_pick(edge_index, next_edge_indices, walk_indices, num_layers)

            # Backtrack: remove the edge added above.
            walk_indices.pop(-1)

        return walk_indices_list
